import torch
import heapq
import random
import pickle
import gzip
import numpy as np

import editdistance

class ReplayBuffer:
    """
    A replay buffer that keeps, for each problem, the ``buffer_size`` items
    with the highest log reward.

    Items are stored per problem in a min-heap of ``(log_reward, plan, sample)``
    tuples, so the lowest-reward item is the one evicted when a problem's heap
    overflows. A per-problem ``exists`` set rejects exact duplicate plans.
    """

    def __init__(self, buffer_size, prb=True, sim_tolerance=0.25):
        """
        Args:
            buffer_size: maximum number of items kept per problem.
            prb: if True, ``sample`` draws items with softmax-weighted
                probabilities (see ``sample``); otherwise it draws uniformly.
            sim_tolerance: stored on the instance but currently unused
                (near-duplicate filtering in ``add`` is disabled).
        """
        self.buffer_size = buffer_size
        self.sim_tolerance = sim_tolerance
        self.prb = prb
        self.reset()

    def reset(self):
        """Drop all stored items for every problem."""
        # problem -> {"sentences": min-heap of (log_reward, plan, sample),
        #             "exists": set of plans currently present in that heap}
        self._buffer = {}

    def add(self, problem, plan, sample, log_reward):
        """
        Insert ``(log_reward, plan, sample)`` into the heap for ``problem``.

        Does nothing if ``plan`` is already stored for ``problem``. When the
        heap grows past ``buffer_size``, the lowest-reward item (possibly the
        one just added) is evicted and its plan becomes insertable again.

        ``problem`` and ``plan`` must be hashable (dict key / set member).
        ``log_reward`` must be orderable, and ``plan`` too on reward ties,
        because heapq compares the stored tuples.
        """
        if problem not in self._buffer:
            self._buffer[problem] = {
                "sentences": [],
                "exists": set(),
            }
        entry = self._buffer[problem]
        if plan in entry["exists"]:
            return
        # NOTE: edit-distance based near-duplicate filtering (which would use
        # self.sim_tolerance) is not implemented; only exact duplicates are
        # rejected above.

        heap = entry["sentences"]
        # The list is only ever mutated through heapq, so it is always a valid
        # heap; no re-heapify is needed before pushing.
        heapq.heappush(heap, (log_reward, plan, sample))
        # Record the plan only after the push succeeded, so a failed comparison
        # cannot leave a plan marked as present without being in the heap.
        entry["exists"].add(plan)

        if len(heap) > self.buffer_size:
            evicted = heapq.heappop(heap)
            entry["exists"].discard(evicted[1])

    def sample(self, batch_size, problem):
        """
        Draw ``batch_size`` items, with replacement, from the heap of ``problem``.

        If ``self.prb`` is True, item ``i`` is drawn with probability
        ``softmax(exp(log_rewards))[i]``; otherwise items are drawn uniformly.

        Returns:
            A tuple ``(log_rewards, plans, samples)`` of three parallel lists of
            length ``batch_size``, or ``(None, None, None)`` if nothing was ever
            added for ``problem``.
        """
        if problem not in self._buffer:
            return None, None, None
        heap = self._buffer[problem]["sentences"]

        if self.prb:
            log_rewards = torch.tensor(
                [item[0] for item in heap], dtype=torch.float32
            )
            # NOTE(review): this is a softmax over exp(log_reward), i.e. the log
            # rewards are exponentiated twice. If the intent is
            # P(i) proportional to reward_i, use torch.softmax(log_rewards, dim=0)
            # instead. exp() also overflows float32 for log rewards above ~88,
            # which produces NaN probabilities.
            probabilities = torch.softmax(torch.exp(log_rewards), dim=0)
            idx = torch.multinomial(
                probabilities, batch_size, replacement=True
            ).tolist()
        else:
            idx = np.random.choice(
                len(heap),
                batch_size,
                replace=True,
            )

        drawn = [heap[i] for i in idx]
        return (
            [item[0] for item in drawn],
            [item[1] for item in drawn],
            [item[2] for item in drawn],
        )

    def print(self):
        """Print every problem key followed by the plans stored for it."""
        for key in self._buffer:
            print(key)
            for item in self._buffer[key]["sentences"]:
                print(item[1])
            print("")

    def save(self, path):
        """Pickle the whole per-problem buffer dict into a gzip file at ``path``."""
        with gzip.open(path, "wb") as f:
            pickle.dump(self._buffer, f)